package it.uniroma2.exp.dtk;

import it.uniroma2.dsk.DS;
import it.uniroma2.dsk.full.BoundedFullDS;
import it.uniroma2.dtk.dt.DT;
import it.uniroma2.dtk.dt.GenericDT;
import it.uniroma2.dtk.dt.partial.GenericPartialDT;
import it.uniroma2.dtk.dt.route.GenericRouteDT;
import it.uniroma2.dtk.dt.subpath.GenericSubPathDT;
import it.uniroma2.exp.AbstractExperiment;
import it.uniroma2.exp.tools.AvgVarCalculator;
import it.uniroma2.sk.SequenceKernel;
import it.uniroma2.tk.PartialTreeKernel;
import it.uniroma2.tk.RouteTreeKernel;
import it.uniroma2.tk.SubpathTreeKernel;
import it.uniroma2.tk.TreeKernel;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.math.statistical.SpearmanCorrelation;
import it.uniroma2.util.tree.RandomTreeGenerator;
import it.uniroma2.util.tree.Tree;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.PrintStream;
import java.util.Vector;

import org.apache.commons.lang.ArrayUtils;

/**
 * @author Lorenzo Dell'Arciprete
 * 
 * This experiment evaluates how similar two kernel results are for trees taken from a certain corpus.
 * The available corpora are artificially generated trees, question classification trees,
 * textual entailment recognition trees and Reuters sentences.
 * Other than the Tree Kernel and Distributed Tree Kernel, a Zero Tree Kernel is included, 
 * that is simply a Tree Kernel with lambda = 0.
 * The kernel value similarity measures are Spearman Correlation and Ratio.
 */
public class KernelComparison extends AbstractExperiment {
	
	/** The kernel operations that can be paired: Tree Kernel, Distributed Tree Kernel, Zero (lambda = 0) Tree Kernel. */
	public enum OPS { TK, DTK, ZTK };
	/** The corpora the compared inputs can be taken from. */
	public enum TYPES { artificial, qc, rte, reuters };
	
	//Values of the "kernel type" custom parameter
	private static final int KERNEL_TREE = 0;
	private static final int KERNEL_SUBPATH = 1;
	private static final int KERNEL_ROUTE = 2;
	private static final int KERNEL_PARTIAL = 3;
	private static final int KERNEL_SEQUENCE = 4;
	
	//The location of the Question Classification trees file (in SVM Light format)
	public final String inputTreesQC = System.getProperty("input.file.qc", "/home/lorenzo/esperimenti/dtk/qc/correlationTestData/corr_test_data.dat");
	//The location of the Textual Entailment Recognition tree files folder
	private static final String rte_base_dir = System.getProperty("input.dir.rte", "/home/lorenzo/workspace/arte/data/experiments/preprocessing/CPW/");
	
	//The Textual Entailment Recognition files; computeRTE reads <t> and <h> lines from each of them
	public final String[] inputTreePairsRTE = {
		rte_base_dir  + "RTE1_dev.xml"
		,rte_base_dir  + "RTE1_test.xml"
		,rte_base_dir  + "RTE2_dev.xml"
		,rte_base_dir  + "RTE2_test.xml"
		,rte_base_dir  + "RTE3_dev.xml"
		,rte_base_dir  + "RTE3_test.xml"
		,rte_base_dir  + "RTE4_test.xml"
		,rte_base_dir  + "RTE5_dev.xml"
		,rte_base_dir  + "RTE5_test.xml"
	};
	
	//The location of the Reuters dataset files folder; each line should contain a sentence enclosed in <BODY>...<BODY> tags
	private static final String reuters_base_dir = System.getProperty("input.dir.reu", "/home/lorenzo/esperimenti/reuters dataset/Trimmed");

	//The distributed kernels and the sequence kernel of the current run; only the ones
	//required by the selected kernel type are non-null
	private DT dtk = null;
	private DS<String> dsk = null;
	private SequenceKernel<String> sk = null;
	//The kernel type of the current run (see the KERNEL_* constants)
	private int kernel = 0;
	
	//The two kernels to be compared
	private OPS OP1;
	private OPS OP2;
	//The corpora for the trees
	private TYPES TYPE;
	
	/**
	 * Configures the experiment from system properties and runs it.
	 * The properties read are: output.file, vector.sizes (comma separated integers),
	 * lambdas (comma separated decimals) and custom.params (underscore separated groups
	 * of comma separated integers). Each one falls back to a built-in default when absent.
	 * 
	 * @param argv ignored
	 * @throws Exception if a property cannot be parsed or the experiment fails
	 */
	public static void main(String [] argv) throws Exception {
		KernelComparison kc = new KernelComparison();
		try {
			kc.setOutputStream(new PrintStream(new File(System.getProperty("output.file", "kernel_comparison.dat"))));
		} catch (FileNotFoundException e) {
			e.printStackTrace();
			//A zero exit status would report success although nothing was run
			System.exit(1);
		}
		String vectorSizes = System.getProperty("vector.sizes");
		if (vectorSizes != null)
			kc.setVectorSizeArray(parseInts(vectorSizes, ","));
		else
			kc.setVectorSizeArray(new int[] {1024, 2048, 4096, 8192});
		kc.setRandomTreeNodes(30);
		String lambdas = System.getProperty("lambdas");
		if (lambdas != null) {
			String[] lString = lambdas.split(",");
			double[] lArray = new double[lString.length];
			for (int i=0; i<lString.length; i++)
				lArray[i] = Double.parseDouble(lString[i]);
			kc.setLambdaArray(lArray);
		}
		else
			kc.setLambdaArray(new double[]{0.2, 0.4, 0.6, 0.8, 1});
		//The corpus type, the operations pair and the kernel type
		//Custom parameters are:
		//Corpus type: 0=artificial, 1=qc, 2=rte, 3=reuters
		//Operations pair: 0=DTKvsTK, 1=DTKvsZTK, 2=ZTKvsTK
		//Kernel type: 0=Tree Kernel, 1=Subpath Kernel, 2=Route Kernel, 3=Partial Tree Kernel, 4=Sequence Kernel
		//Mu (for the Partial Tree Kernel): n*0.2
		//Bound on the Sequence Kernel: p
		String customParams = System.getProperty("custom.params");
		if (customParams != null) {
			String[] params = customParams.split("_");
			int[][] paramsArray = new int[params.length][];
			for (int j=0; j<params.length; j++)
				paramsArray[j] = parseInts(params[j], ",");
			kc.setCustomParameters(paramsArray);
		}
		else
			kc.setCustomParameters(new int[][] {{3},{0},{4},{1},{3}});
		kc.runAll();
	}
	
	/**
	 * Parses a separated list of integers, tolerating whitespace around each value.
	 * 
	 * @param list the text to parse
	 * @param separator the regular expression that separates the values
	 * @return the parsed values, in order
	 * @throws NumberFormatException if a value is not an integer
	 */
	private static int[] parseInts(String list, String separator) {
		String[] tokens = list.split(separator);
		int[] values = new int[tokens.length];
		for (int i=0; i<tokens.length; i++)
			values[i] = Integer.parseInt(tokens[i].trim());
		return values;
	}

	/**
	 * Runs one comparison, configured by the current custom parameters
	 * (corpus type, operations pair, kernel type, mu multiplier, sequence kernel bound),
	 * and prints the Spearman correlation and the ratio average and variance to the output stream.
	 * 
	 * @throws IllegalArgumentException if the corpus type or the operations pair is unknown
	 * @throws Exception if the corpus cannot be read or a kernel computation fails
	 */
	@Override
	protected void runExperiment() throws Exception {
		//Forget the kernels of the previous run: computePair and DTK decide what to do
		//by checking which of them is non-null, so a stale one would be applied to the wrong input type
		dtk = null;
		dsk = null;
		sk = null;
		
		int typeIndex = currentCustomParameters[0];
		if (typeIndex < 0 || typeIndex >= TYPES.values().length)
			throw new IllegalArgumentException("Unknown corpus type: " + typeIndex);
		TYPE = TYPES.values()[typeIndex];
		
		int opPair = currentCustomParameters[1];
		switch (opPair) {
		case 0:
			OP1 = OPS.DTK;
			OP2 = OPS.TK;
			break;
		case 1:
			OP1 = OPS.DTK;
			OP2 = OPS.ZTK;
			break;
		case 2:
			OP1 = OPS.ZTK;
			OP2 = OPS.TK;
			break;
		default:
			//Otherwise the pair of the previous run (or null) would be silently used
			throw new IllegalArgumentException("Unknown operations pair: " + opPair);
		}
		
		kernel = currentCustomParameters[2];
		if (TYPE == TYPES.reuters &&  kernel != KERNEL_SEQUENCE) {
			out.println("Only SK allowed for Reuters dataset");
			return;
		}
		double mu = currentCustomParameters[3]*0.2;
		int bound = currentCustomParameters[4];
		usePos = TYPE != TYPES.artificial && kernel != KERNEL_SUBPATH;
		switch (kernel) {
		case KERNEL_TREE:
			TreeKernel.lexicalized = lexicalized;
			TreeKernel.lambda = lambda;
			dtk = new GenericDT(randomOffset, vectorSize, usePos, lexicalized, lambda, compositionType);
			break;
		case KERNEL_SUBPATH:
			SubpathTreeKernel.lambda = lambda;
			dtk = new GenericSubPathDT(randomOffset, vectorSize, usePos, lambda, compositionType);
			break;
		case KERNEL_ROUTE:
			RouteTreeKernel.lambda = lambda;
			//NOTE(review): false is passed where the other DT constructors receive usePos - confirm this is intended
			dtk = new GenericRouteDT(randomOffset, vectorSize, false, lambda, compositionType);
			break;
		case KERNEL_PARTIAL:
			//Remember that mu and lambda are swapped in the DTK formulation
			PartialTreeKernel.mu = lambda;
			PartialTreeKernel.lambda = mu;
			dtk = new GenericPartialDT(randomOffset, vectorSize, usePos, lexicalized, lambda, mu, compositionType);
			break;
		case KERNEL_SEQUENCE:
			sk = new SequenceKernel<String>(lambda);
			dsk = new BoundedFullDS<String>(randomOffset, vectorSize, lambda, bound, compositionType);
			break;
		default:
			//No kernel is created: TK will return 0 and DTK will report the missing initialization
			break;
		}
		
		double spearmanCorrelation = 0;
		AvgVarCalculator avgRatioCor = new AvgVarCalculator();
		switch(TYPE) {
		case artificial: spearmanCorrelation = computeArtificial(avgRatioCor); break;
		case rte: spearmanCorrelation = computeRTE(avgRatioCor); break;
		case qc: spearmanCorrelation = computeQC(avgRatioCor); break;
		case reuters: spearmanCorrelation = computeReuters(avgRatioCor); break;
		}
		out.println("Spearman\tRatio");
		out.print(String.format("%.3f\t", spearmanCorrelation));
		out.println(String.format("%.3f\t%.3f\t", avgRatioCor.getAvg(), avgRatioCor.getVar()));
	}
	
	/**
	 * Compares the two kernels on the Reuters files: every two consecutive lines of
	 * each file in the Reuters folder form a pair of sentences.
	 * 
	 * @param ratio collects the ratio between the two kernel values of each pair
	 * @return the Spearman correlation between the two kernels over all the pairs of all the files
	 * @throws FileNotFoundException if the Reuters folder cannot be listed
	 * @throws Exception if a file cannot be read or a kernel computation fails
	 */
	private double computeReuters(AvgVarCalculator ratio) throws Exception {
		Vector<Double> a_vec = new Vector<Double>(); 
		Vector<Double> b_vec = new Vector<Double>();
		//listFiles returns null (rather than throwing) when the path is not a readable directory
		File[] files = new File(reuters_base_dir).listFiles();
		if (files == null)
			throw new FileNotFoundException("Reuters directory not found or not readable: " + reuters_base_dir);
		for (File file : files) {
			System.out.println("\n"+file);
			int count = 0;
			BufferedReader in = new BufferedReader(new FileReader(file));
			try {
				while (true) {
					String treet = in.readLine();
					String treeh = in.readLine();
					if (treet == null || treeh == null)
						break;
					treet = treet.substring(treet.indexOf("<BODY>")+6, treet.lastIndexOf("<BODY>")).trim(); 
					treeh = treeh.substring(treeh.indexOf("<BODY>")+6, treeh.lastIndexOf("<BODY>")).trim(); 
					computePair(treet, treeh, a_vec, b_vec, ratio);
					System.out.print('.');
					count += 2;
					if (count%100 == 0) System.out.println("" +count);
				}
			} finally {
				//Release the file even when a pair fails
				in.close();
			}
		}
		System.out.println("end");
		return spearman(a_vec, b_vec);
	}

	/**
	 * Compares the two kernels on the RTE files: a line starting with &lt;t&gt; followed
	 * by a line starting with &lt;h&gt; forms a pair of trees.
	 * 
	 * @param ratio collects the ratio between the two kernel values of each pair
	 * @return the average, over the RTE files, of the per-file Spearman correlation
	 * @throws Exception if a file cannot be read or a kernel computation fails
	 */
	public double computeRTE(AvgVarCalculator ratio) throws Exception {
		double[] resultsSpear = new double[inputTreePairsRTE.length];
		for (int n=0; n<inputTreePairsRTE.length; n++) {
			Vector<Double> a_vec = new Vector<Double>(); 
			Vector<Double> b_vec = new Vector<Double>();
			BufferedReader in = new BufferedReader(new FileReader(inputTreePairsRTE[n]));
			try {
				int count = 0;
				while (true) {
					String treet = in.readLine();
					if (treet == null)
						break;
					treet = treet.trim();
					if (!treet.startsWith("<t>"))
						continue;
					String treeh = in.readLine();
					//A <t> line at the very end of the file has nothing to be paired with
					if (treeh == null)
						break;
					treeh = treeh.trim();
					if (!treeh.startsWith("<h>"))
						continue;
					treet = treet.substring(3, treet.indexOf("</t>")).trim(); 
					treeh = treeh.substring(3, treeh.indexOf("</h>")).trim(); 
					computePair(treet, treeh, a_vec, b_vec, ratio);
					System.out.print('.');
					if (count%100 == 0) System.out.println("" +count);
					count += 2;
				}
			} finally {
				//Release the file even when a pair fails
				in.close();
			}
			resultsSpear[n] = spearman(a_vec, b_vec);
			System.out.println("end");
		}
		double tot = 0;
		for (double d : resultsSpear)
			tot += d;
		return tot/inputTreePairsRTE.length;
	}
	
	/**
	 * Compares the two kernels on 1000 pairs of randomly generated trees.
	 * A pair whose computation fails is reported on the standard output and skipped.
	 * 
	 * @param ratio collects the ratio between the two kernel values of each pair
	 * @return the Spearman correlation between the two kernels over the computed pairs
	 * @throws Exception if the correlation cannot be computed
	 */
	public double computeArtificial(AvgVarCalculator ratio) throws Exception {
		RandomTreeGenerator rtg = new RandomTreeGenerator(randomOffset);
		Vector<Double> a_vec = new Vector<Double>(); 
		Vector<Double> b_vec = new Vector<Double>();
		for (int i=0; i<1000; i++) {
			System.out.print('.');
			if (i%100 == 0) System.out.println("" +i);
			String treeh = rtg.generateRandomTree(randomTreeLabels, randomTreeMaxDegree, randomTreeNodes);
			String treet = rtg.generateRandomTree(randomTreeLabels, randomTreeMaxDegree, randomTreeNodes);
			try {
				computePair(treet, treeh, a_vec, b_vec, ratio);
			} catch (Exception e) {
				//Best effort: keep going, but report the exception type too (the message alone may be null)
				System.out.println("Skipping tree pair " + i + ": " + e);
			}
		}
		System.out.println("end");
		return spearman(a_vec, b_vec);
	}
	
	/**
	 * Compares the two kernels on the Question Classification file: every two consecutive
	 * lines form a pair, the tree of each line being the text between |BT| and |ET|.
	 * 
	 * @param ratio collects the ratio between the two kernel values of each pair
	 * @return the Spearman correlation between the two kernels over all the pairs
	 * @throws Exception if the file cannot be read or a kernel computation fails
	 */
	public double computeQC(AvgVarCalculator ratio) throws Exception {
		Vector<Double> a_vec = new Vector<Double>(); 
		Vector<Double> b_vec = new Vector<Double>();
		BufferedReader in = new BufferedReader(new FileReader(inputTreesQC));
		try {
			int count = 0;
			while (true) {
				String treet = in.readLine();
				String treeh = in.readLine();
				if (treet == null || treeh == null)
					break;
				treet = treet.substring(treet.indexOf("|BT|")+4, treet.indexOf("|ET|")).trim(); 
				treeh = treeh.substring(treeh.indexOf("|BT|")+4, treeh.indexOf("|ET|")).trim(); 
				computePair(treet, treeh, a_vec, b_vec, ratio);
				System.out.print('.');
				if (count%100 == 0) System.out.println("" +count);
				count += 2;
			}
		} finally {
			//Release the file even when a pair fails
			in.close();
		}
		System.out.println("end");
		return spearman(a_vec, b_vec);
	}
	
	/**
	 * Computes the Spearman correlation between two series of kernel values.
	 * 
	 * @param a_vec the values of the first kernel
	 * @param b_vec the values of the second kernel, in the same order
	 * @return the value returned by SpearmanCorrelation for the two series
	 * @throws Exception if the correlation computation fails
	 */
	private static double spearman(Vector<Double> a_vec, Vector<Double> b_vec) throws Exception {
		SpearmanCorrelation corr = new SpearmanCorrelation();
		return corr.spearmanCorrelationCoefficient(toFloatArray(a_vec), toFloatArray(b_vec));
	}
	
	/**
	 * Copies a list of values into a float array, narrowing each value to float.
	 * 
	 * @param values the values to copy
	 * @return a new array with the same values, in the same order
	 */
	private static float[] toFloatArray(Vector<Double> values) {
		float[] result = new float[values.size()];
		for (int i=0; i<result.length; i++)
			result[i] = values.elementAt(i).floatValue();
		return result;
	}
	
	/**
	 * Applies the two compared kernels to a pair of inputs and records the results.
	 * For the Reuters corpus the inputs are sentences, split on spaces; otherwise they are
	 * Penn trees, flattened to their terminal labels when a distributed sequence kernel is in use.
	 * 
	 * @param treet the first input
	 * @param treeh the second input
	 * @param a_vec receives the value of the first kernel
	 * @param b_vec receives the value of the second kernel
	 * @param ratio receives the ratio between the two values, unless the second one is 0
	 * @throws Exception if an input cannot be parsed or a kernel computation fails
	 */
	public void computePair(String treet, String treeh, Vector<Double> a_vec, Vector<Double> b_vec, AvgVarCalculator ratio) throws Exception {
		// Extracting kernel arguments
		Object xt;
		Object xh;
		if (TYPE == TYPES.reuters) {
			xt = treet.split(" ");
			xh = treeh.split(" ");
		}
		else {
			xt = Tree.fromPennTree(treet);
			xh = Tree.fromPennTree(treeh);
			if (dsk != null) {
				xt = treeToSentence((Tree)xt);
				xh = treeToSentence((Tree)xh);
			}
		}
		double value1 = op(xt, xh, OP1);
		double value2 = op(xt, xh, OP2);
		a_vec.add(value1);
		b_vec.add(value2);
		//The ratio is undefined when the second kernel is 0
		if (value2 != 0)
			ratio.addSample(value1/value2);
	}
	
	/**
	 * Applies the requested kernel operation to a pair of inputs.
	 * 
	 * @param xt the first input (a Tree or a String[], depending on the kernel type)
	 * @param xh the second input, of the same type as the first
	 * @param op the operation to apply
	 * @return the kernel value
	 * @throws Exception if the operation is unknown or the kernel computation fails
	 */
	public double op(Object xt, Object xh, OPS op) throws Exception {
		switch(op) {
			case TK: return TK(xt, xh); 
			case DTK: return DTK(xt, xh); 
			case ZTK: return ZTK(xt, xh); 
			default: throw new Exception("Unknown operation!");
		}
	}
	
	/**
	 * Computes the exact kernel selected by the current kernel type.
	 * 
	 * @param xt the first input: a String[] for the Sequence Kernel, a Tree otherwise
	 * @param xh the second input, of the same type as the first
	 * @return the kernel value, or 0 if the kernel type is unknown
	 */
	public double TK(Object xt, Object xh) {
		switch (kernel) {
		case KERNEL_TREE:
			return TreeKernel.value((Tree)xt, (Tree)xh);
		case KERNEL_SUBPATH:
			return SubpathTreeKernel.value((Tree)xt, (Tree)xh);
		case KERNEL_ROUTE:
			return RouteTreeKernel.value((Tree)xt, (Tree)xh);
		case KERNEL_PARTIAL:
			return PartialTreeKernel.value((Tree)xt, (Tree)xh);
		case KERNEL_SEQUENCE:
			return sk.value((String[])xt, (String[])xh);
		default:
			return 0;
		}
	}
	
	/**
	 * Computes the distributed kernel, as the dot product of the distributed
	 * representations of the two inputs.
	 * 
	 * @param xt the first input: a Tree for the distributed tree kernels, a String[] for the distributed sequence kernel
	 * @param xh the second input, of the same type as the first
	 * @return the dot product of the two distributed vectors
	 * @throws Exception if no distributed kernel has been initialized or the computation fails
	 */
	public double DTK(Object xt, Object xh) throws Exception {
		double[] vectort, vectorh;
		if (dtk != null) {
			vectort = dtk.dt((Tree)xt);
			vectorh = dtk.dt((Tree)xh);
		}
		else if (dsk != null) {
			vectort = dsk.ds((String[])xt);
			vectorh = dsk.ds((String[])xh);
		}
		else
			throw new Exception("Distributed kernel not initialized");
		return ArrayMath.dot(vectort,vectorh);
	}
	
	/**
	 * Computes the Zero Tree Kernel, i.e. the exact kernel with lambda temporarily set to 0.
	 * It is available for the Tree, Subpath and Route kernels only.
	 * 
	 * @param xt the first tree
	 * @param xh the second tree
	 * @return the kernel value for lambda = 0
	 * @throws Exception if the current kernel type has no Zero Tree Kernel
	 */
	public double ZTK(Object xt, Object xh) throws Exception {
		//Lambda is a static field of the kernel classes: it is restored in a finally block
		//so that a failing computation cannot leave it at 0 for the following ones
		if (kernel == KERNEL_TREE) {
			double oldLambda = TreeKernel.lambda;
			TreeKernel.lambda = 0;
			try {
				return TK(xt, xh);
			} finally {
				TreeKernel.lambda = oldLambda;
			}
		}
		else if (kernel == KERNEL_SUBPATH) {
			double oldLambda = SubpathTreeKernel.lambda;
			SubpathTreeKernel.lambda = 0;
			try {
				return TK(xt, xh);
			} finally {
				SubpathTreeKernel.lambda = oldLambda;
			}
		}
		else if (kernel == KERNEL_ROUTE) {
			double oldLambda = RouteTreeKernel.lambda;
			RouteTreeKernel.lambda = 0;
			try {
				return TK(xt, xh);
			} finally {
				RouteTreeKernel.lambda = oldLambda;
			}
		}
		else
			throw new Exception("ZTK not supported");
	}
	
	/**
	 * Flattens a tree into the sequence of the labels of its terminal nodes, left to right.
	 * 
	 * @param t the tree to flatten
	 * @return the terminal labels, in order
	 */
	private String[] treeToSentence(Tree t) {
		Vector<String> words = new Vector<String>();
		collectTerminals(t, words);
		return words.toArray(new String[words.size()]);
	}
	
	/**
	 * Appends the labels of the terminal nodes of a tree to a list, visiting the children in order.
	 * A single list is filled during the visit, instead of concatenating one array per child.
	 * 
	 * @param t the tree to visit
	 * @param words receives the terminal labels
	 */
	private void collectTerminals(Tree t, Vector<String> words) {
		if (t.isTerminal()) {
			words.add(t.getRootLabel());
			return;
		}
		for (int i=0; i<t.getChildren().size(); i++)
			collectTerminals(t.getChildren().get(i), words);
	}
}
